{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from scipy.sparse import csr_matrix\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "\n",
    "#Check Ethan Rosenthal github"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 143,
   "metadata": {},
   "outputs": [],
   "source": [
    "class tensor_fact(nn.Module):\n",
    "    def __init__(self,n_pat=10,n_meas=5,n_t=25,l_dim=2,n_u=2,n_w=3):\n",
    "        super(tensor_fact,self).__init__()\n",
    "        self.n_pat=n_pat\n",
    "        self.n_meas=n_meas\n",
    "        self.n_t=n_t\n",
    "        self.l_dim=l_dim\n",
    "        self.n_u=n_u\n",
    "        self.n_w=n_w\n",
    "        self.pat_lat=nn.Embedding(n_pat,l_dim) #sparse gradients ?\n",
    "        self.pat_lat.weight=nn.Parameter(torch.randn([n_pat,l_dim]))\n",
    "        self.meas_lat=nn.Embedding(n_meas,l_dim)\n",
    "        self.meas_lat.weight=nn.Parameter(torch.randn([n_meas,l_dim]))\n",
    "        self.time_lat=nn.Embedding(n_t,l_dim).double()\n",
    "        self.beta_u=torch.nn.Parameter(torch.randn([n_u,l_dim],requires_grad=True).double())\n",
    "        self.beta_w=torch.nn.Parameter(torch.randn([n_w,l_dim],requires_grad=True).double())\n",
    "    def forward(self,idx_pat,idx_meas,idx_t,cov_u,cov_w):\n",
    "        pred=((self.pat_lat(idx_pat)+torch.mm(cov_u,self.beta_u))*(self.meas_lat(idx_meas))*(self.time_lat(idx_t)+torch.mm(cov_w,self.beta_w))).sum(1) \n",
    "        return(pred)\n",
    "        \n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 127,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TensorFactDataset(Dataset):\n",
    "    def __init__(self,csv_file_serie=\"lab_short_tensor_train.csv\",file_path=\"~/Documents/Data/Full_MIMIC/\",transform=None):\n",
    "        self.lab_short=pd.read_csv(file_path+csv_file_serie)\n",
    "        self.length=len(self.lab_short.index)\n",
    "        self.pat_num=self.lab_short[\"UNIQUE_ID\"].nunique()\n",
    "        self.cov_values=[chr(i) for i in range(ord('A'),ord('A')+18)]\n",
    "        self.time_values=[\"TIME_STAMP\",\"TIME_SQ\"]\n",
    "        #print(self.lab_short.dtypes)\n",
    "        self.tensor_mat=self.lab_short.as_matrix()\n",
    "    def __len__(self):\n",
    "        return self.length\n",
    "    def __getitem__(self,idx):\n",
    "        #ind_sparse=self.sparse_tens.getrow(idx)\n",
    "        \n",
    "        #print(self.lab_short[\"VALUENUM\"].iloc[idx].values)\n",
    "        #return(self.lab_short.iloc[idx].as_matrix())\n",
    "        return(self.tensor_mat[idx,:])\n",
    "        #return([torch.from_numpy(self.lab_short.iloc[idx][[\"UNIQUE_ID\",\"LABEL_CODE\",\"TIME_STAMP\"]].astype('int64').as_matrix()),torch.from_numpy(self.lab_short.iloc[idx][self.cov_values].as_matrix()),torch.from_numpy(self.lab_short.iloc[idx][self.time_values].astype('float64').as_matrix()),torch.tensor(self.lab_short[\"VALUENUM\"].iloc[idx],dtype=torch.double)])\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 144,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "EPOCH : 0\n",
      "0\n",
      "0.5869932174682617\n",
      "1\n",
      "0.5800011157989502\n",
      "2\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-144-25d7fca266ea>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m     29\u001b[0m         \u001b[0mpreds\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmod\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mindexes\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mindexes\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mindexes\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mcov_u\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mcov_w\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     30\u001b[0m         \u001b[0mloss\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcriterion\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpreds\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mtarget\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 31\u001b[0;31m         \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     32\u001b[0m         \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     33\u001b[0m         \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0mstarttime\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph)\u001b[0m\n\u001b[1;32m     91\u001b[0m                 \u001b[0mproducts\u001b[0m\u001b[0;34m.\u001b[0m \u001b[0mDefaults\u001b[0m \u001b[0mto\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     92\u001b[0m         \"\"\"\n\u001b[0;32m---> 93\u001b[0;31m         \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     94\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     95\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables)\u001b[0m\n\u001b[1;32m     87\u001b[0m     Variable._execution_engine.run_backward(\n\u001b[1;32m     88\u001b[0m         \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 89\u001b[0;31m         allow_unreachable=True)  # allow_unreachable flag\n\u001b[0m\u001b[1;32m     90\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     91\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "#With Adam optimizer\n",
    "import time \n",
    "\n",
    "tens_dataset=TensorFactDataset()\n",
    "dataloader = DataLoader(tens_dataset, batch_size=100000,shuffle=True)\n",
    "\n",
    "mod=tensor_fact(n_pat=tens_dataset.pat_num,n_meas=30,n_t=101,n_u=18,n_w=1)\n",
    "mod.double()\n",
    "\n",
    "optimizer=torch.optim.Adam(mod.parameters(), lr=0.001)\n",
    "# reduction=\"sum\" is the non-deprecated spelling of size_average=False.\n",
    "criterion = nn.MSELoss(reduction=\"sum\")\n",
    "epochs_num=1\n",
    "\n",
    "for epoch in range(epochs_num):\n",
    "    print(\"EPOCH : \"+str(epoch))\n",
    "    for i_batch,sampled_batch in enumerate(dataloader):\n",
    "        print(i_batch)\n",
    "        starttime=time.time()\n",
    "        # Batch column layout used below: columns 1-3 are the three integer\n",
    "        # indexes, columns 4-21 the 18 cov_u covariates, and the last column\n",
    "        # the regression target.\n",
    "        indexes=sampled_batch[:,1:4].to(torch.long)\n",
    "        cov_u=sampled_batch[:,4:22]\n",
    "        # Column 3 is used twice: as the third index and as the single cov_w covariate.\n",
    "        cov_w=sampled_batch[:,3].unsqueeze(1)\n",
    "        target=sampled_batch[:,-1]\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        # Call the module itself (not .forward) so nn.Module hooks are honoured.\n",
    "        preds=mod(indexes[:,0],indexes[:,1],indexes[:,2],cov_u,cov_w)\n",
    "        loss=criterion(preds,target)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        print(time.time()-starttime)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 3.,  0.,  0., 75.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,\n",
       "        0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  9.])"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load the dataset and display one row through the indexing protocol.\n",
    "a=TensorFactDataset()\n",
    "b=a[3]\n",
    "b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([ 3,  0,  0], dtype=torch.int32)"
      ]
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "indexes=torch.Tensor(b[:3]).int()\n",
    "indexes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "csv_file_serie=\"lab_short_tensor.csv\"\n",
    "file_path=\"~/Documents/Data/Full_MIMIC/\"\n",
    "lab_short=pd.read_csv(file_path+csv_file_serie)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'lab_short' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-2-9aec4f53a443>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0midx\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m174729\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0ma\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlab_short\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0miloc\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"VALUENUM\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m: name 'lab_short' is not defined"
     ]
    }
   ],
   "source": [
    "idx=[174729,3]\n",
    "a=lab_short.iloc[idx][\"VALUENUM\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 282,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "9.0"
      ]
     },
     "execution_count": 282,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "lab_short.iloc[3][\"VALUENUM\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "a=np.array([1,2,3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([ 1.,  2.,  3.])"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.Tensor(a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.float64"
      ]
     },
     "execution_count": 79,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.randn(3).double().dtype"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "64515"
      ]
     },
     "execution_count": 81,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.022705276605628894"
      ]
     },
     "execution_count": 84,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "lab_dim=3\n",
    "lab_dim*(lab_short[\"UNIQUE_ID\"].nunique()+30+101)/len(lab_short.index)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.043873822209581304"
      ]
     },
     "execution_count": 85,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(lab_short.index)/(lab_short[\"UNIQUE_ID\"].nunique()*30*101)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 90,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.7071067811865476"
      ]
     },
     "execution_count": 90,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.sqrt(0.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 125,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(0, Parameter containing:\n",
      "tensor([[ 0.3364,  2.3571],\n",
      "        [-1.5298,  2.7330],\n",
      "        [ 1.1631,  0.3089],\n",
      "        [-2.1276, -0.3360],\n",
      "        [-0.6021, -0.0876],\n",
      "        [ 0.4438, -0.0219],\n",
      "        [-1.8984,  1.8621],\n",
      "        [ 0.1408,  0.7330],\n",
      "        [ 0.3796, -1.3308],\n",
      "        [-0.5094, -1.4694],\n",
      "        [ 0.5451,  0.5626],\n",
      "        [ 1.6695,  1.3612],\n",
      "        [ 1.1912, -0.8589],\n",
      "        [ 0.7064, -1.7086],\n",
      "        [ 0.3451, -0.3606],\n",
      "        [-0.3710, -1.2147],\n",
      "        [ 1.3785, -0.9152],\n",
      "        [ 0.2956, -2.4411]], dtype=torch.float64))\n",
      "(1, Parameter containing:\n",
      "tensor([[ 0.9371,  0.1345]], dtype=torch.float64))\n",
      "(2, Parameter containing:\n",
      "tensor([[ 3.0396e-01,  1.7614e+00],\n",
      "        [ 2.2776e-01,  5.4416e-01],\n",
      "        [-1.0709e+00,  1.0852e+00],\n",
      "        ...,\n",
      "        [-1.2617e-01,  6.6292e-01],\n",
      "        [-6.0991e-01,  2.3274e-01],\n",
      "        [-5.7461e-01, -2.3834e-01]]))\n",
      "(3, Parameter containing:\n",
      "tensor([[-0.8869,  1.0772],\n",
      "        [-0.4788,  1.6747],\n",
      "        [-2.0322, -0.5766],\n",
      "        [ 0.8952, -0.1532],\n",
      "        [-0.9540,  0.7711],\n",
      "        [ 0.2156,  0.3281],\n",
      "        [ 0.6661, -0.6945],\n",
      "        [-0.7517, -1.4523],\n",
      "        [ 0.6853, -0.0883],\n",
      "        [-0.3067,  0.4328],\n",
      "        [-0.6045, -0.6998],\n",
      "        [-0.8664,  1.2287],\n",
      "        [-1.8229,  0.3619],\n",
      "        [ 0.2984, -0.0964],\n",
      "        [ 0.6667, -0.5132],\n",
      "        [-0.6685,  0.5399],\n",
      "        [ 0.7176,  1.6406],\n",
      "        [-0.2538, -0.6193],\n",
      "        [ 1.1974, -0.9206],\n",
      "        [-0.6551,  0.3048],\n",
      "        [-0.7878,  1.8730],\n",
      "        [ 0.1199,  1.1207],\n",
      "        [ 1.3652, -0.7674],\n",
      "        [ 0.6410,  1.8403],\n",
      "        [ 0.2855, -0.3916],\n",
      "        [ 0.8999, -1.4076],\n",
      "        [ 0.0882, -0.6397],\n",
      "        [-0.5374,  0.6909],\n",
      "        [-0.0994,  0.7779],\n",
      "        [ 1.1834, -0.0294]], dtype=torch.float64))\n",
      "(4, Parameter containing:\n",
      "tensor([[ 1.9227, -0.1298],\n",
      "        [ 0.8106, -0.0223],\n",
      "        [ 1.6437,  0.2469],\n",
      "        [ 0.9762, -1.7706],\n",
      "        [-1.7625,  1.2683],\n",
      "        [ 0.7756,  1.2466],\n",
      "        [ 0.9856,  0.1665],\n",
      "        [-0.1647, -0.6369],\n",
      "        [ 0.5762,  1.0341],\n",
      "        [ 1.0482,  0.2388],\n",
      "        [-0.8288, -0.2067],\n",
      "        [ 1.2218, -1.7234],\n",
      "        [-0.0155,  0.9846],\n",
      "        [-0.0986, -0.2208],\n",
      "        [-0.3046, -1.9771],\n",
      "        [-2.0706,  0.4858],\n",
      "        [-1.4847, -1.6744],\n",
      "        [-0.4276,  1.6520],\n",
      "        [ 0.7825, -0.1188],\n",
      "        [ 1.5570,  0.1737],\n",
      "        [ 0.8409, -1.1478],\n",
      "        [-1.0708, -1.2359],\n",
      "        [ 0.5540,  0.2549],\n",
      "        [-0.2678, -0.9698],\n",
      "        [-1.7953,  1.7369],\n",
      "        [ 0.3127, -0.9351],\n",
      "        [ 0.1082, -0.2460],\n",
      "        [ 0.2146, -2.0104],\n",
      "        [-1.6838,  0.7142],\n",
      "        [-0.0240,  0.7360],\n",
      "        [ 0.0345,  0.2556],\n",
      "        [-0.1985,  0.1042],\n",
      "        [-0.4134,  1.5730],\n",
      "        [-0.1408, -0.9498],\n",
      "        [-0.3701,  0.7804],\n",
      "        [-1.7473,  0.0384],\n",
      "        [ 1.6873, -1.4604],\n",
      "        [-0.1670,  0.5167],\n",
      "        [ 1.4892,  0.0972],\n",
      "        [ 0.6154,  1.3571],\n",
      "        [ 1.5896,  0.4510],\n",
      "        [-1.4406,  0.1019],\n",
      "        [ 0.2754, -0.3581],\n",
      "        [ 0.2209, -2.0199],\n",
      "        [ 1.3804, -0.8279],\n",
      "        [ 0.7851,  0.0908],\n",
      "        [ 0.1146,  0.3797],\n",
      "        [ 0.7980, -2.6698],\n",
      "        [-0.9831,  0.4521],\n",
      "        [ 0.3307, -1.5532],\n",
      "        [-0.2761,  0.0792],\n",
      "        [ 0.1129, -0.1825],\n",
      "        [-0.2690,  0.3082],\n",
      "        [ 0.8628, -1.1996],\n",
      "        [ 0.3259,  0.3020],\n",
      "        [ 0.2457, -0.0779],\n",
      "        [ 0.9737,  1.9354],\n",
      "        [-0.9634, -2.0090],\n",
      "        [ 0.2293, -1.8193],\n",
      "        [ 0.7783, -0.1275],\n",
      "        [ 1.0896,  0.6622],\n",
      "        [ 0.3740,  0.6133],\n",
      "        [-1.2698, -0.0423],\n",
      "        [ 0.6214,  0.6424],\n",
      "        [-3.3794, -1.1840],\n",
      "        [ 0.2349, -1.4366],\n",
      "        [-1.3056,  0.8200],\n",
      "        [ 1.8413,  0.1455],\n",
      "        [ 0.6322, -0.4988],\n",
      "        [ 0.4381,  0.4820],\n",
      "        [ 0.4135,  0.2526],\n",
      "        [ 2.3152,  0.5714],\n",
      "        [-0.8804, -0.0230],\n",
      "        [-0.5558, -1.5065],\n",
      "        [ 1.2159, -0.6327],\n",
      "        [-0.2200,  0.3854],\n",
      "        [-0.6089,  0.5881],\n",
      "        [ 0.4178,  0.8282],\n",
      "        [ 0.1608,  0.2542],\n",
      "        [ 1.5019, -1.6554],\n",
      "        [-0.8679, -1.1868],\n",
      "        [ 1.1505,  1.0765],\n",
      "        [-0.3830,  0.6154],\n",
      "        [ 0.3615,  1.0490],\n",
      "        [ 0.8038,  0.4391],\n",
      "        [-0.3449,  0.1777],\n",
      "        [-0.1672,  0.3899],\n",
      "        [-0.0767, -1.5128],\n",
      "        [ 0.2231, -0.9346],\n",
      "        [ 1.2055,  1.0357],\n",
      "        [-1.6525,  0.4283],\n",
      "        [-0.0784, -0.2126],\n",
      "        [ 0.3885, -1.4838],\n",
      "        [ 0.6038,  0.1339],\n",
      "        [-0.3390,  0.6118],\n",
      "        [ 0.3033, -1.5409],\n",
      "        [-0.1413, -1.8203],\n",
      "        [ 0.0502, -1.8374],\n",
      "        [ 0.3630, -0.9576],\n",
      "        [ 0.8562, -0.2544],\n",
      "        [-0.9446,  0.7180]], dtype=torch.float64))\n"
     ]
    }
   ],
   "source": [
    "# Rebuild the model (this rebinds `mod`) and list every parameter's index,\n",
    "# name, shape and dtype instead of dumping all weight values.\n",
    "mod=tensor_fact(n_pat=tens_dataset.pat_num,n_meas=30,n_t=101,n_u=18,n_w=1)\n",
    "for i,(name,param) in enumerate(mod.named_parameters()):\n",
    "    print(i,name,tuple(param.shape),param.dtype)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 148,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Parameter containing:\n",
       "tensor([[ 0.5258, -0.0692, -0.4372, -0.6689,  1.9913],\n",
       "        [-1.4903, -2.2111,  0.9129, -0.2722,  1.8234],\n",
       "        [-0.1660, -0.2575,  0.6380, -0.7633, -0.9242],\n",
       "        [ 2.7356,  0.2454, -1.1247,  0.4840,  1.6507],\n",
       "        [ 1.3054,  0.3608, -0.3589,  1.2875,  0.3299]], dtype=torch.float64)"
      ]
     },
     "execution_count": 148,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a=nn.Embedding(n_meas,l_dim).double()\n",
    "a.weight"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 149,
   "metadata": {},
   "outputs": [],
   "source": [
    "a.weight=nn.Parameter(0.5*torch.randn([5,5]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 150,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Parameter containing:\n",
       "tensor([[ 0.6721, -0.3864,  0.6945, -0.6208, -0.3603],\n",
       "        [-0.7730,  0.2382, -0.3704, -0.0439, -0.5443],\n",
       "        [ 0.2399,  0.3886, -0.3603,  0.0543,  1.0006],\n",
       "        [ 0.9537,  0.1054, -0.3608,  0.0165,  0.5242],\n",
       "        [-0.1822, -0.3476,  1.1672, -0.0865, -0.5477]])"
      ]
     },
     "execution_count": 150,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a.weight"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 140,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Embedding(5, 5)"
      ]
     },
     "execution_count": 140,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a.double()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
